{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Memory Augmented Neural Network: Simple Illustration\n",
    "\n",
    "As shown earlier, the controller of a Neural Turing Machine can use content-based addressing, location-based addressing, or both. MANN, by contrast, uses a purely content-based memory writer.\n",
    "\n",
    "MANN also uses a new addressing scheme called least recently used access (LRUA). The least recently used memory location is determined from the read operation, and the read operation itself uses content-based addressing. In short, we read with content-based addressing and write either to the location that was least recently used or, depending on a learned gate, to the location that was most recently read.\n",
    "\n",
    "Picture credits: [MANN paper](https://arxiv.org/pdf/1605.06065.pdf)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"Images/mann.png\" width=\"1500\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this tutorial, we will do the following, step by step:\n",
    "1. Implement Read Operation\n",
    "2. Implement Write Operation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Step 1: Import the required libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "import copy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Step 2: Implement the memory module (similar to the NTM version, with some changes)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "<img src=\"Images/MANN_read.png\" width=\"500\"/>\n"
   ]
  },
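  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, here is a short sketch of the content-based read as it is written in the MANN paper. The notation is the paper's, not this notebook's: $k_t$ is the key emitted by the controller, $M_t(i)$ is row $i$ of the memory, $w_t^r$ are the read weights and $r_t$ is the read vector.\n",
    "\n",
    "$$K(k_t, M_t(i)) = \\frac{k_t \\cdot M_t(i)}{\\lVert k_t \\rVert \\lVert M_t(i) \\rVert}$$\n",
    "\n",
    "$$w_t^r(i) = \\frac{\\exp(K(k_t, M_t(i)))}{\\sum_j \\exp(K(k_t, M_t(j)))}$$\n",
    "\n",
    "$$r_t = \\sum_i w_t^r(i) M_t(i)$$\n",
    "\n",
    "In the code cells that follow, `F.cosine_similarity` plays the role of $K$, `F.softmax` normalises the similarities into $w_t^r$, and `torch.matmul(w, memory)` produces $r_t$. The extra NTM parameters (`beta`, `g`, `s`, `gamma`) do not appear in these equations."
   ]
  },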
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Memory(nn.Module):\n",
    "    def __init__(self, M, N, controller_out):\n",
    "        super(Memory, self).__init__()\n",
    "        self.N = N\n",
    "        self.M = M\n",
    "        self.read_lengths = self.N + 1 + 1 + 3 + 1\n",
    "        self.write_lengths = self.N + 1 + 1 + 3 + 1 + self.N + self.N\n",
    "        self.w_last = [] # define to keep track of weight_vector at each time step.\n",
    "        self.reset_memory()\n",
    "\n",
    "    def get_weights(self):\n",
    "        return self.w_last\n",
    "\n",
    "    def reset_memory(self):\n",
    "        # resets the memory for both read and write operations at start of new sequence (new input)\n",
    "        self.w_last = []\n",
    "        self.w_last.append(torch.zeros([1, self.M], dtype=torch.float32))\n",
    "\n",
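    "    def address(self, k, beta, g, s, gamma, memory, w_last):\n",
    "        # MANN uses content-based addressing only; g, s, gamma and w_last are\n",
    "        # kept for parity with the NTM interface but are not used here.\n",
    "        w_r = self._similarity(k, beta, memory)\n",
    "        return w_r\n",
    "\n",
    "    # Implementing Similarity\n",
    "    def _similarity(self, k, beta, memory):\n",
    "        # Cosine similarity between each key (B, N) and each memory row (M, N) -> (B, M).\n",
    "        # beta is unused: the weights are a plain softmax over the similarities.\n",
    "        w = F.cosine_similarity(memory.unsqueeze(0), k.unsqueeze(1), dim=-1, eps=1e-16)\n",
    "        # normalize over memory locations\n",
    "        w = F.softmax(w, dim=-1)\n",
    "        return w # return w_r^t for reading purpose\n"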
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Step 3: Implement the read operation\n",
    "Here, we define the read head, which accesses the memory and reads from it using the read operation discussed in the chapter above.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class ReadHead(Memory):\n",
    "\n",
    "    def __init__(self, M, N, controller_out):\n",
    "        super(ReadHead, self).__init__(M, N, controller_out)\n",
    "        self.fc_read = nn.Linear(controller_out, self.read_lengths)\n",
    "        self.initialize_parameters()\n",
    "\n",
    "    def initialize_parameters(self):\n",
    "        # Initialize the linear layer\n",
    "        nn.init.xavier_uniform_(self.fc_read.weight, gain=1.4)\n",
    "        nn.init.normal_(self.fc_read.bias, std=0.01)\n",
    "\n",
    "    def read(self, memory, w):\n",
    "        # Read vector: sum of the memory rows weighted by w\n",
    "        return torch.matmul(w, memory)\n",
    "\n",
    "    def forward(self, x, memory):\n",
    "        param = self.fc_read(x) # gather parameters\n",
    "        # split into the parameters k, beta, g, shift, and gamma\n",
    "        k, beta, g, s, gamma = torch.split(param, [self.N, 1, 1, 3, 1], dim=1)\n",
    "        k = torch.tanh(k)\n",
    "        beta = F.softplus(beta)\n",
    "        g = torch.sigmoid(g)\n",
    "        s = F.softmax(s, dim=1)\n",
    "        gamma = 1 + F.softplus(gamma)\n",
    "        # obtain the current read weights from Memory\n",
    "        w_r = self.address(k, beta, g, s, gamma, memory, self.w_last[-1])\n",
    "        # append to w_last to keep track of the content-based weights\n",
    "        self.w_last.append(w_r)\n",
    "        # read from memory using the weights above\n",
    "        mem = self.read(memory, w_r)\n",
    "        # w_r is returned so that the caller can pass it on to the write head\n",
    "        return mem, w_r"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Step 4: Implement the write operation\n",
    "Similar to the read operation, we now implement the write operation.\n",
    "\n",
    "Note: Both read and write heads use a fully connected layer to produce the parameters (k, beta, g, s, gamma) for content addressing.\n",
    "\n",
    "<img src=\"Images/MANN_write.png\" width=\"500\"/>"
   ]
  },
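  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The code comments in the next cell refer to Equations (F2) to (F4). Those labels are not defined in this notebook; the sketch below assumes they correspond to the following update rules from the MANN paper, where $w_t^u$ are the usage weights, $w_t^{lu}$ the least-used weights, $w_t^w$ the write weights, $\\gamma$ a decay parameter, $\\alpha$ a learned scalar gate and $\\sigma$ the sigmoid function.\n",
    "\n",
    "$$w_t^u \\leftarrow \\gamma w_{t-1}^u + w_t^r + w_t^w$$\n",
    "\n",
    "$$w_t^{lu}(i) = \\begin{cases} 0 & \\text{if } w_t^u(i) > m(w_t^u, n) \\\\ 1 & \\text{otherwise} \\end{cases}$$\n",
    "\n",
    "$$w_t^w \\leftarrow \\sigma(\\alpha) w_{t-1}^r + (1 - \\sigma(\\alpha)) w_{t-1}^{lu}$$\n",
    "\n",
    "$$M_t(i) \\leftarrow M_{t-1}(i) + w_t^w(i) k_t$$\n",
    "\n",
    "Here $m(v, n)$ denotes the $n$-th smallest element of $v$, and $n$ is the number of reads to memory. Note that $\\gamma$ in these equations is a decay factor for the usage weights, which is a different quantity from the NTM sharpening parameter that is also called `gamma` in the heads."
   ]
  },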
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class WriteHead(Memory):\n",
    "\n",
    "    def __init__(self, M, N, controller_out, usage_decay=0.95, n_reads=1):\n",
    "        super(WriteHead, self).__init__(M, N, controller_out)\n",
    "        self.fc_write = nn.Linear(controller_out, self.write_lengths)\n",
    "        # Assumed hyperparameters (not in the original signature): the decay of the\n",
    "        # usage weights in Equation (F2) and the number of least-used slots to mark.\n",
    "        self.usage_decay = usage_decay\n",
    "        self.n_reads = n_reads\n",
    "        self.initialize_parameters()\n",
    "\n",
    "    def initialize_parameters(self):\n",
    "        # Initialize the linear layer\n",
    "        nn.init.xavier_uniform_(self.fc_write.weight, gain=1.4)\n",
    "        nn.init.normal_(self.fc_write.bias, std=0.01)\n",
    "\n",
    "    def reset_memory(self):\n",
    "        # also reset the usage weights that are carried between time steps\n",
    "        super(WriteHead, self).reset_memory()\n",
    "        self.prev_w_u = torch.zeros([1, self.M], dtype=torch.float32)\n",
    "\n",
    "    def usage_weight_vector(self, prev_w_u, w_read, w_write, gamma):\n",
    "        # usage weight vector, Equation (F2); gamma is the decay factor\n",
    "        w_u = gamma * prev_w_u + w_read + w_write\n",
    "        return w_u\n",
    "\n",
    "    def least_used(self, w_u, n_reads=1):\n",
    "        # mark the n_reads least-used memory locations with 1 and all others with 0\n",
    "        _, indices = torch.topk(-1 * w_u, k=n_reads, dim=-1)\n",
    "        wlu_t = torch.sum(F.one_hot(indices, w_u.shape[-1]).float(), dim=1)\n",
    "        return indices, wlu_t\n",
    "\n",
    "    def mann_write(self, memory, w_write, a, gamma, prev_w_u, w_read, k):\n",
    "        # obtain the usage weight vector from the previous read and write weights\n",
    "        w_u = self.usage_weight_vector(prev_w_u, w_read, w_write, gamma)\n",
    "        # calculate the least-used weight vector\n",
    "        _, w_least_used_weight_t = self.least_used(w_u, self.n_reads)\n",
    "        # write weights as per Equation (F3): gate between the last read\n",
    "        # locations and the least-used locations\n",
    "        gate = torch.sigmoid(a)\n",
    "        w_write = gate * w_read + (1 - gate) * w_least_used_weight_t\n",
    "        # memory update as per Equation (F4): add the key to every row,\n",
    "        # scaled by that row's write weight\n",
    "        memory_update = memory + torch.matmul(w_write.transpose(0, 1), k)\n",
    "        return memory_update, w_write, w_u\n",
    "\n",
    "    def forward(self, x, memory, w_read):\n",
    "        # w_read: read weights from the previous step, as returned by ReadHead.forward\n",
    "        param = self.fc_write(x) # gather parameters\n",
    "        # split into k, beta, g, shift, gamma, a and e\n",
    "        k, beta, g, s, gamma, a, e = torch.split(param, [self.N, 1, 1, 3, 1, self.N, self.N], dim=1)\n",
    "        k = torch.tanh(k)\n",
    "        # The gate in Equation (F3) is a scalar, while this head emits N values for a;\n",
    "        # as a simplifying assumption they are averaged into one gate per sample.\n",
    "        a = a.mean(dim=1, keepdim=True)\n",
    "        # beta, g, s, gamma and e are NTM parameters that the MANN write does not use.\n",
    "        # self.w_last[-1] holds the write weights of the previous step (zeros at the start).\n",
    "        mem, w_write, w_u = self.mann_write(memory, self.w_last[-1], a, self.usage_decay, self.prev_w_u, w_read, k)\n",
    "        # keep track of the write weights and usage weights for the next step\n",
    "        self.w_last.append(w_write)\n",
    "        self.prev_w_u = w_u\n",
    "        return mem, w_write\n"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
